/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_

#include <vector>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include <memory>

#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "operator/ops.h"
#include "utils/symbolic.h"

namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
// Produces clones of a func graph whose output env is replaced by the item
// stored under a given symbolic key. Clones are memoized per
// (graph, key, default node), so repeated requests reuse the same clone.
class EnvGetitemTransform {
 public:
  EnvGetitemTransform() : cache_() {}
  ~EnvGetitemTransform() = default;

  // Returns a clone of `fg` whose output is the lookup of `key` in fg's output env.
  //  - The output is peeled through chained {prim::kPrimEnvSetItem, env, key, value}
  //    nodes, walking inward from the graph output. At the first node whose key
  //    equals `key`, the clone's output becomes that node's value.
  //  - If no node in the chain stores `key`, the clone's output becomes
  //    {prim::kPrimEnvGetItem, env, key, default_node}. Here env is what remains
  //    after peeling off the non-matching EnvSetItem nodes.
  // A malformed EnvSetItem node (not 4 inputs, or a non-SymbolicKeyInstance key)
  // is reported through MS_LOG(EXCEPTION).
  FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) {
    // operator[] default-constructs the per-graph table the first time `fg` is seen.
    auto &cache = cache_[fg];
    auto hash_key = std::make_pair(key, default_node);
    auto cached = cache.find(hash_key);
    if (cached != cache.end()) {
      return cached->second;
    }

    std::ostringstream ss("env", std::ostringstream::app);
    if (key->node() != nullptr) {
      ss << key->node()->ToString();
    }

    auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
    auto env = new_fg->output();
    while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
      // {prim::kPrimEnvSetItem, env, symbolickey, value}
      auto &inputs = env->cast<CNodePtr>()->inputs();
      if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
        MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance.";
      }

      // Read everything needed from this node before `env` is moved to the next
      // link, so `inputs` is never used after its owner has been released here.
      auto value = inputs[3];
      auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
      env = inputs[1];
      if (*key2 == *key) {
        new_fg->set_output(value);
        cache[hash_key] = new_fg;
        return new_fg;
      }
    }
    new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node}));
    cache[hash_key] = new_fg;
    return new_fg;
  }

 private:
  // graph -> ((key, default node) -> transformed clone)
  std::unordered_map<FuncGraphPtr,
                     std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>>
    cache_;
};
}  // namespace internal

// Eliminates a lookup into an empty env constant:
//   {prim::kPrimEnvGetItem, C1, C2, Y} -> Y
// C1 must be an EnvInstance value node with Len() == 0, and C2 must satisfy IsVNode.
// Y is the default operand; it is returned unchanged.
class NewEnvGetItem : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    // Predicate for the default operand: it accepts anything and records it.
    auto capture_default = [this](const AnfNodePtr &candidate) -> bool {
      y_ = candidate;
      return true;
    };

    AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode<EnvInstance>, IsVNode, capture_default})(node);
    const bool env_is_empty = (env_ != nullptr) && (env_->Len() == 0);
    return env_is_empty ? y_ : nullptr;
  }

  // Records only the first value node visited. This assumes AnfVisitor::Match
  // visits operands left to right, so the recorded node is the EnvInstance in
  // operand 1 (TODO confirm).
  void Visit(const ValueNodePtr &vnode) override {
    if (env_ != nullptr) {
      return;
    }
    env_ = GetValueNode<EnvInstancePtr>(vnode);
  }

  // Clears the state captured by the previous match.
  void Reset() {
    env_ = nullptr;
    y_ = nullptr;
  }

 private:
  AnfNodePtr y_{nullptr};
  EnvInstancePtr env_{nullptr};
};

// Distributes an env lookup over an env addition:
//   {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
//   {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}}
class AddEnvGetItem : public AnfVisitor {
 public:
  AddEnvGetItem() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {}
  ~AddEnvGetItem() override = default;

  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    is_match_ = false;
    // Operand 1 must be a binary EnvAdd CNode: {prim::kPrimEnvAdd, X, Y}.
    auto is_binary_env_add = [](const AnfNodePtr &candidate) -> bool {
      if (!IsPrimitiveCNode(candidate, prim::kPrimEnvAdd)) {
        return false;
      }
      return candidate->cast<CNodePtr>()->size() == 3;
    };
    AnfVisitor::Match(prim::kPrimEnvGetItem, {is_binary_env_add, IsVNode, IsNode})(node);

    if (!is_match_) {
      return nullptr;
    }
    auto fg = node->func_graph();
    if (fg == nullptr) {
      return nullptr;
    }

    // {prim::kPrimEnvGetItem, {...}, C, Z}
    auto get_item = node->cast<CNodePtr>();
    auto env_add = get_item->input(1)->cast<CNodePtr>();
    auto key = get_item->input(2);
    auto default_value = get_item->input(3);

    // Builds {prim::kPrimEnvGetItem, env, C, Z} in the enclosing graph.
    auto lookup_in = [&fg, &key, &default_value](const AnfNodePtr &env) {
      return fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, key, default_value});
    };

    // {prim::kPrimEnvAdd, X, Y}: look up X first, then Y.
    auto lhs = lookup_in(env_add->input(1));
    auto rhs = lookup_in(env_add->input(2));
    return fg->NewCNode({NewValueNode(PrimHyperAdd_), lhs, rhs});
  }

  // Match only visits operands after all predicates pass, so reaching here means a match.
  void Visit(const AnfNodePtr &) override { is_match_ = true; }

 private:
  bool is_match_{false};
  ValuePtr PrimHyperAdd_;
};

// Resolves an env lookup through a chain of env stores:
//   {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z}
//     -> Y                                    if C1 == C2
//     -> the same test repeated on X while X is itself an EnvSetItem
//     -> {prim::kPrimEnvGetItem, X', C2, Z}   if no store in the chain uses C2,
//        where X' is the env left after stripping every EnvSetItem in the chain.
class EnvGetSetItem : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    is_match_ = false;
    // Operand 1 must be a well-formed store {prim::kPrimEnvSetItem, X, C1, Y}
    // whose key C1 is a SymbolicKeyInstance value node.
    auto is_keyed_set = [](const AnfNodePtr &candidate) -> bool {
      if (!IsPrimitiveCNode(candidate, prim::kPrimEnvSetItem)) {
        return false;
      }
      auto &operands = candidate->cast<CNodePtr>()->inputs();
      return operands.size() == 4 && IsValueNode<SymbolicKeyInstance>(operands[2]);
    };
    AnfVisitor::Match(prim::kPrimEnvGetItem, {is_keyed_set, IsValueNode<SymbolicKeyInstance>, IsNode})(node);

    if (!is_match_ || node->func_graph() == nullptr) {
      return nullptr;
    }

    // {prim::kPrimEnvGetItem, {...}, C2, Z}
    auto get_item = node->cast<CNodePtr>();
    auto wanted_key_node = get_item->input(2);
    auto wanted_key = GetValueNode<SymbolicKeyInstancePtr>(wanted_key_node);
    auto default_value = get_item->input(3);

    // Walk the EnvSetItem chain starting at operand 1; the first store that uses
    // the wanted key supplies the result. Operand 1 itself was validated by
    // is_keyed_set, so the malformed-store error can only fire for nested stores.
    AnfNodePtr env = get_item->input(1);
    while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
      // {prim::kPrimEnvSetItem, env, symbolickey, value}
      auto &operands = env->cast<CNodePtr>()->inputs();
      if (operands.size() != 4 || !IsValueNode<SymbolicKeyInstance>(operands[2])) {
        MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance.";
      }

      auto stored_key = GetValueNode<SymbolicKeyInstancePtr>(operands[2]);
      if (*stored_key == *wanted_key) {
        return operands[3];
      }
      env = operands[1];
    }

    return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, wanted_key_node, default_value});
  }

  // Match only visits operands after all predicates pass, so reaching here means a match.
  void Visit(const AnfNodePtr &) override { is_match_ = true; }

 private:
  bool is_match_{false};
};

// Pushes an env lookup into the graph that produces the env:
//   {prim::kPrimEnvGetItem, {G, Xs}, C, Y} -> {G', Xs}
// G is a FuncGraph value node. G' is a clone of G, produced by
// internal::EnvGetitemTransform, whose output is the lookup of C
// (default Y) on G's output env.
class IncorporateEnvGetitem : public AnfVisitor {
 public:
  IncorporateEnvGetitem() : env_get_item_transform_() {}
  ~IncorporateEnvGetitem() override = default;

  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    is_match_ = false;
    // Operand 1 must be a call whose callee (input 0) is a FuncGraph value node.
    auto IsGCNode = [](const AnfNodePtr &candidate) -> bool {
      auto cnode = candidate->cast<CNodePtr>();
      if (cnode == nullptr || cnode->size() < 1) {
        return false;
      }
      return IsValueNode<FuncGraph>(cnode->input(0));
    };
    AnfVisitor::Match(prim::kPrimEnvGetItem, {IsGCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);

    // The replacement call is created in node's own graph. A detached node has
    // no graph, so bail out, as the other env passes do, instead of dereferencing
    // null below. Checking here, before cloning, also avoids cloning G needlessly.
    if (!is_match_ || node->func_graph() == nullptr) {
      return nullptr;
    }

    // {prim::kPrimEnvGetItem, {...}, C, Y}
    auto cnode = node->cast<CNodePtr>();
    auto inp1 = cnode->input(1)->cast<CNodePtr>();
    auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
    auto default_v = cnode->input(3);

    // {G, Xs}: inp1 keeps the call node alive, so its inputs can be borrowed.
    const auto &inputs = inp1->inputs();
    auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
    auto new_fg = env_get_item_transform_(fg, key, default_v);

    // {G', Xs}
    std::vector<AnfNodePtr> args;
    args.reserve(inputs.size());
    args.push_back(NewValueNode(new_fg));
    (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());

    return node->func_graph()->NewCNode(args);
  }

  // Match only visits operands after all predicates pass, so reaching here means a match.
  void Visit(const AnfNodePtr &) override { is_match_ = true; }

 private:
  bool is_match_{false};
  internal::EnvGetitemTransform env_get_item_transform_;
};

// Pushes an env lookup into both graph branches of a switch call:
//   {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
//     -> {{prim::kPrimSwitch, X, G1', G2'}, Xs}
// Each Gi' is a clone of Gi whose output is the lookup of C (default Y) on Gi's
// output env. G1 and G2 must both be FuncGraph value nodes.
class IncorporateEnvGetitemSwitch : public AnfVisitor {
 public:
  IncorporateEnvGetitemSwitch() : env_get_item_transform_() {}
  ~IncorporateEnvGetitemSwitch() override = default;

  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    is_match_ = false;
    // Operand 1 must be a call whose callee (input 0) is a Switch CNode.
    auto calls_switch = [](const AnfNodePtr &candidate) -> bool {
      auto call = candidate->cast<CNodePtr>();
      return call != nullptr && call->size() >= 1 && IsPrimitiveCNode(call->input(0), prim::kPrimSwitch);
    };
    AnfVisitor::Match(prim::kPrimEnvGetItem, {calls_switch, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
    if (!is_match_ || node->func_graph() == nullptr) {
      return nullptr;
    }

    // {prim::kPrimEnvGetItem, {...}, C, Y}
    auto get_item = node->cast<CNodePtr>();
    auto call = get_item->input(1)->cast<CNodePtr>();
    auto key = GetValueNode<SymbolicKeyInstancePtr>(get_item->input(2));
    auto default_v = get_item->input(3);

    // {{prim::kPrimSwitch, X, G1, G2}, Xs}: re-run the matcher on the callee to
    // require graph constants in both branch slots.
    auto call_inputs = call->inputs();
    const auto &switch_node = call_inputs[0];
    is_match_ = false;
    AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode<FuncGraph>, IsValueNode<FuncGraph>})(switch_node);
    if (!is_match_) {
      return nullptr;
    }

    // {prim::kPrimSwitch, X, G1, G2}; transform G1 before G2.
    auto sw = switch_node->cast<CNodePtr>();
    auto cond = sw->input(1);
    auto branch1 = GetValueNode<FuncGraphPtr>(sw->input(2));
    auto branch2 = GetValueNode<FuncGraphPtr>(sw->input(3));
    auto new_branch1 = env_get_item_transform_(branch1, key, default_v);
    auto new_branch2 = env_get_item_transform_(branch2, key, default_v);

    auto fg = node->func_graph();
    std::vector<AnfNodePtr> args{
      fg->NewCNode({NewValueNode(prim::kPrimSwitch), cond, NewValueNode(new_branch1), NewValueNode(new_branch2)})};
    (void)args.insert(args.end(), call_inputs.begin() + 1, call_inputs.end());

    return fg->NewCNode(args);
  }

  // Match only visits operands after all predicates pass, so reaching here means a match.
  void Visit(const AnfNodePtr &) override { is_match_ = true; }

 private:
  bool is_match_{false};
  internal::EnvGetitemTransform env_get_item_transform_;
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
